/*
 * black_box_reduction.cpp:
 * Logistic regression using our black-box reduction
 * to turnstile l2 point query. A JL matrix is used to
 * compress the weights. An estimate of the weight
 * vector is then fed to a CountSketch. Recovery is
 * again done through the CCF02 (Charikar-Chen-Farach-Colton
 * 2002) recovery procedure applied to the CountSketch
 * data structure.
 *
 * The implementation is mostly the same as logistic_sketch.h/jl_recovery_sketch.h
 * and logistic_sketch.cpp/jl_recovery_sketch.cpp.
 */

#include "black_box_reduction.h"

#include <math.h>

#include <cstdlib>
#include <iostream>
#include <new>
#include <numeric>
#include <stdexcept>

#include "util.h"

namespace wmsketch {

// Builds an untrained model: bias 0, regularization scale 1, iteration 0, and a
// zero-filled depth x 2^log2_width table of JL-compressed weights.
//
// Throws std::invalid_argument if log2_width exceeds MAX_LOG2_WIDTH, if
// lr_init is not positive, or if depth is zero. Throws std::bad_alloc if the
// weight table cannot be allocated.
BlackBoxReduction::BlackBoxReduction(
		uint32_t log2_width,
		uint32_t depth,
		int32_t seed,
		float lr_init,
		float l2_reg)
  : bias_{0.f},
	lr_init_{lr_init},
	l2_reg_{l2_reg},
	scale_{1.f},
	t_{0},
	depth_{depth},
	hash_fn_(depth, seed),
	hash_buf_(depth, 0),
	point_query_l2_(log2_width, depth, seed) {

		if (log2_width > BlackBoxReduction::MAX_LOG2_WIDTH) {
			throw std::invalid_argument("Invalid sketch width");
		}

		if (lr_init <= 0.) {
			throw std::invalid_argument("Initial learning rate must be positive");
		}

		// With depth 0, calloc(0, ...) below may return a pointer that must not
		// be dereferenced (weights_[0] = ... would be UB), and the JL entries
		// are scaled by 1/sqrt(depth).
		if (depth == 0) {
			throw std::invalid_argument("Sketch depth must be positive");
		}

		// Unsigned shift: well-defined for every width allowed above.
		const uint32_t width = uint32_t{1} << log2_width;
		width_mask_ = width - 1;

		// One contiguous depth x width block of cells plus a table of row
		// pointers into it. The cell count is computed in size_t so that
		// width * depth cannot wrap around in 32-bit arithmetic.
		const size_t num_cells = static_cast<size_t>(width) * depth;
		weights_ = static_cast<float**>(calloc(depth, sizeof(float*)));
		if (weights_ == nullptr) {
			throw std::bad_alloc();
		}
		float* cells = static_cast<float*>(calloc(num_cells, sizeof(float)));
		if (cells == nullptr) {
			// The destructor does not run when a constructor throws, so release
			// the row-pointer table here.
			free(weights_);
			weights_ = nullptr;
			throw std::bad_alloc();
		}
		for (uint32_t i = 0; i < depth; i++) {
			weights_[i] = cells + static_cast<size_t>(i) * width;
		}
}

// Releases the weight table allocated by the constructor. Row 0's pointer is
// the start of the single contiguous cell buffer, so it is freed first, then
// the array of row pointers itself.
// NOTE(review): this class owns raw calloc'd memory. Unless the header deletes
// or defines the copy constructor/assignment, copying an instance would lead
// to a double free here. Confirm against the header.
BlackBoxReduction::~BlackBoxReduction() {
	float** rows = weights_;
	float* cells = rows[0];
	free(cells);
	free(rows);
}

// Recovers the weight of feature `key` through the black-box l2 point query.
// The CountSketch receives updates divided by the lazy regularization factor
// scale_, so the factor is multiplied back in here.
float BlackBoxReduction::get(uint32_t key) {
	const auto unscaled_estimate = point_query_l2_.get(key);
	return scale_ * unscaled_estimate;
}

// For every nonzero coordinate x[k], computes the inner product of the
// compressed weight table with that feature's sparse JL column
// R_{x[k].first} and stores it in weight_sums_[k]. The column has depth_
// nonzeros, each +-1/sqrt(depth_).
// As a side effect, the per-feature hashes are cached in hash_buf_ (depth_
// consecutive entries per feature, in the order of x). update() reads them
// from there instead of rehashing.
void BlackBoxReduction::get_weights(const std::vector<std::pair<uint32_t, float>>& x) {
	const uint64_t num_features = x.size();
	const uint64_t hashes_needed = depth_ * num_features;
	if (hashes_needed > hash_buf_.size()) {
		hash_buf_.resize(hashes_needed);
	}
	weight_sums_.resize(num_features);

	// Magnitude shared by every nonzero JL entry; only the sign varies.
	const double entry_magnitude = 1.0 / sqrt(depth_);

	for (uint64_t k = 0; k < num_features; ++k) {
		uint32_t* feature_hashes = hash_buf_.data() + k * depth_;
		hash_fn_.hash(feature_hashes, x[k].first);

		float projection = 0;
		for (uint32_t row = 0; row < depth_; ++row) {
			const uint32_t h = feature_hashes[row];
			// Top hash bit selects the sign; the low bits select the column.
			const float jl_entry = ((h >> 31) ? +1 : -1) * entry_magnitude;
			projection += jl_entry * weights_[row][h & width_mask_];
		}
		weight_sums_[k] = projection;
	}
}


// Returns scale_ * z^T R x: the model's margin without the bias term, where z
// is the JL-compressed weight table and R the sparse JL matrix.
// An empty x yields 0. Otherwise weight_sums_ and hash_buf_ are refreshed for x
// as a side effect of get_weights().
float BlackBoxReduction::dot(const std::vector<std::pair<uint32_t, float>>& x) {
	if (x.empty()) {
		return 0.f;
	}
	// After this call, weight_sums_[k] == z^T R_{x[k].first}.
	get_weights(x);

	float total = 0.f;
	size_t k = 0;
	for (const auto& entry : x) {
		total += entry.second * weight_sums_[k];
		++k;
	}
	return scale_ * total;
}

// Classifies x as positive iff its margin (dot(x) + bias) is non-negative.
bool BlackBoxReduction::predict(const std::vector<std::pair<uint32_t, float>>& x) {
	const float margin = dot(x) + bias_;
	const bool positive = margin >= 0.;
	return positive;
}

// Returns the learned intercept. It is stored outside the sketch and is not
// multiplied by scale_.
float BlackBoxReduction::bias() {
	const float intercept = bias_;
	return intercept;
}

// Returns the global multiplier applied to the sketched weights. It starts at 1
// and is multiplied by (1 - lr * l2_reg) on every non-empty update.
float BlackBoxReduction::scale() {
	const float factor = scale_;
	return factor;
}

// One SGD step of L2-regularized logistic regression on (x, label), identical
// to the JL recovery sketch update with one addition. The JL-compressed weights
// move by -R_{key} * u * val for each nonzero (key, val) of x, and the
// CountSketch also receives (key, -u * val), so it tracks the weight vector
// divided by scale_.
// Returns the prediction (margin >= 0) made *before* the step. An empty x
// leaves the model unchanged and predicts from the bias alone.
bool BlackBoxReduction::update(const std::vector<std::pair<uint32_t, float>>& x, bool label) {
	if (x.empty()) {
		return bias_ >= 0;
	}

	const int y = label ? +1 : -1;
	const float lr = lr_init_ / (1.f + lr_init_ * l2_reg_ * t_);
	// dot() also refreshes hash_buf_ with every feature's JL hashes, which the
	// weight update below reuses.
	const float margin = dot(x) + bias_;
	const float g = logistic_grad(y * margin);

	// L2 shrinkage is applied lazily through scale_, so the stored step is
	// expressed relative to the new scale.
	scale_ *= (1 - lr * l2_reg_);
	const float u = lr * y * g / scale_;
	const double entry_magnitude = 1.0 / sqrt(depth_);

	for (size_t k = 0; k < x.size(); ++k) {
		const uint32_t key = x[k].first;
		const float val = x[k].second;
		const uint32_t* feature_hashes = hash_buf_.data() + k * depth_;

		for (uint32_t row = 0; row < depth_; ++row) {
			const uint32_t h = feature_hashes[row];
			const int sign = (h >> 31) ? +1 : -1;
			weights_[row][h & width_mask_] -= (sign * entry_magnitude) * u * val;
		}

		point_query_l2_.update(key, -u * val);
	}

	bias_ -= lr * y * g;
	++t_;
	return margin >= 0;
}

// Performs the same SGD step as update(x, label) and also reports a per-feature
// weight estimate in new_weights. new_weights is resized to x.size(), and
// entry k corresponds to x[k].
//
// new_weights[k] is the CountSketch estimate for key x[k].first taken *before*
// this step, minus u * x[k].second (the exact contribution of this step).
// All estimates are read in a separate pass before any sketch update, so that
// hash collisions between features of the same example cannot leak partial
// updates into each other's estimates.
//
// NOTE(review): these values are in the sketch's internal units, i.e. NOT
// multiplied by scale_, whereas get() returns scale_ * estimate. Confirm this
// is what callers of new_weights expect.
//
// Returns the prediction (margin >= 0) made before the step. An empty x leaves
// the model unchanged and predicts from the bias alone.
bool BlackBoxReduction::update(std::vector<float>& new_weights, const std::vector<std::pair<uint32_t, float>>& x, bool label) {
	const uint64_t n = x.size();
	new_weights.resize(n);
	if (n == 0) {
		return bias_ >= 0;
	}

	int y = label ? +1 : -1;
	float lr = lr_init_ / (1.f + lr_init_ * l2_reg_ * t_);
	// dot() also fills hash_buf_ with the JL hashes of every feature in x,
	// which the JL weight update below reuses.
	float z = dot(x) + bias_;
	float g = logistic_grad(y * z);
	scale_ *= (1 - lr * l2_reg_);
	float u = lr * y * g / scale_;

	// Pre-update estimates, queried by feature key. The position idx is not a
	// key, so querying it would return unrelated coordinates.
	for (uint64_t idx = 0; idx < n; idx++) {
		new_weights[idx] = point_query_l2_.get(x[idx].first);
	}

	// Magnitude shared by every nonzero JL entry; only the sign varies.
	const double entry_magnitude = 1.0 / sqrt(depth_);
	for (uint64_t idx = 0; idx < n; idx++) {
		const uint32_t key = x[idx].first;
		const float val = x[idx].second;
		const uint32_t* feature_hashes = hash_buf_.data() + idx * depth_;

		// Update JL weights: z -= R_{key} * u * val.
		for (uint32_t i = 0; i < depth_; i++) {
			const uint32_t h = feature_hashes[i];
			const int sgn = (h >> 31) ? +1 : -1;
			weights_[i][h & width_mask_] -= (sgn * entry_magnitude) * u * val;
		}

		// Add the exact gradient contribution to the pre-update estimate.
		new_weights[idx] -= u * val;

		// The CountSketch tracks the weight vector divided by scale_.
		point_query_l2_.update(key, -1 * u * val);
	}

	bias_ -= lr * y * g;
	t_++;
	return z >= 0;
}

// Single-key updates are not supported by the black-box reduction. Calling this
// always throws std::logic_error.
bool BlackBoxReduction::update(uint32_t key, bool label) {
	(void) key;
	(void) label;
	throw std::logic_error("Not implemented.");
}


}
